from sklearn import model_selection
from sklearn.tree import DecisionTreeClassifier
import numpy as np
from utils.mlmodel_util import get_model_info_sklearn
from utils.mlmodel_util import eval_metric
from utils.model_util import load_model
from utils.mlmodel_util import preprocess
import logging
from utils.format_util import dup_name_handler

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def predict(dataframe, para_list, output):
    """Run a stored sklearn classifier over `dataframe` and append its predictions.

    Args:
        dataframe: input table containing the feature columns.
        para_list: parameter dict; uses "model_id", "feature_col" and "label_col".
        output: mutable result dict; the name of the prediction column is
            recorded under output["result"]["output_params"]["output_cols"].

    Returns:
        The dataframe with one extra column of decoded class predictions.
    """
    # --------------------load the model-----------------
    model_save_path, label_list, feature_list = get_model_info_sklearn(para_list["model_id"])
    mlmodel = load_model(model_save_path)

    # --------------------make predictions----------------
    X = preprocess(dataframe, para_list, feature_list)
    pred_raw = mlmodel.predict(X)
    # Map encoded class indices back to the original label values.
    pred = [label_list[p] for p in pred_raw]

    # Resolve name collisions with existing columns BEFORE recording the
    # output column name, so the name reported in `output` always matches
    # the column actually written to the dataframe (the original recorded
    # the raw name and could report a column that does not exist).
    output_cols = dup_name_handler("_classification_",
                                   para_list["feature_col"] + [para_list["label_col"]])
    output["result"]["output_params"]["output_cols"] = output_cols
    dataframe[output_cols] = pred
    return dataframe


def get_tree(tree, feature_num):
    """Extract the structure of a fitted decision tree and build one
    representative sample per leaf.

    Args:
        tree: a fitted ``sklearn.tree._tree.Tree`` (i.e. ``estimator.tree_``),
            or any object exposing ``node_count``, ``children_left``,
            ``children_right``, ``feature`` and ``threshold`` arrays.
        feature_num: number of input features (width of each generated sample).

    Returns:
        result: dict describing the tree (node count, max depth, per-node
            depth, leaf flags, children, split features and thresholds).
        X_plot: list of feature vectors, one per leaf, each crafted to follow
            the split decisions on the path from the root to that leaf.
        leave_index: node ids of the leaves, aligned with ``X_plot``.
    """
    n_nodes = tree.node_count
    children_left = tree.children_left
    children_right = tree.children_right
    feature = tree.feature
    threshold = tree.threshold

    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    parent_list = np.zeros(n_nodes)
    as_left = np.zeros(n_nodes)   # as_left[c] == 1 iff node c is a left child
    as_right = np.zeros(n_nodes)  # as_right[c] == 1 iff node c is a right child
    parent_list[0] = -1  # the root has no parent

    # Iterative DFS; `pop` ensures each node is visited exactly once.
    stack = [(0, 0)]  # (node id, depth), starting from the root
    while stack:
        node_id, depth = stack.pop()
        node_depth[node_id] = depth

        # A node is a split node iff its children differ (leaves have both -1).
        if children_left[node_id] != children_right[node_id]:
            parent_list[children_left[node_id]] = node_id
            parent_list[children_right[node_id]] = node_id
            as_left[children_left[node_id]] = 1
            as_right[children_right[node_id]] = 1
            stack.append((children_left[node_id], depth + 1))
            stack.append((children_right[node_id], depth + 1))
        else:
            is_leaves[node_id] = True

    result = {}
    result["n_nodes"] = n_nodes
    result["max_depth"] = int(node_depth.max()) + 1
    result["node_depth"] = list(node_depth)
    result["is_leaves"] = list(is_leaves)
    result["children_left"] = list(children_left)
    result["children_right"] = list(children_right)
    result["feature"] = list(feature)
    result["threshold"] = list(threshold)

    leave_index = []
    X_plot = []
    for leaf in range(n_nodes):
        if not is_leaves[leaf]:
            continue
        X = np.zeros(feature_num)
        node = leaf
        parent = int(parent_list[node])
        # Walk from the leaf up to the root, applying the parent's split so
        # that X is routed towards `node`. The first (deepest) constraint on
        # each feature wins, hence the `== 0` guard.
        # BUG FIX: the original looped `while node != -1` and, on the root
        # iteration, indexed feature[parent_list[0]] == feature[-1] (the
        # LAST node's data, -2 for sklearn leaves), corrupting X[-2] — or
        # raising IndexError when feature_num == 1. Stopping once the
        # parent is -1 processes exactly the edges on the root-to-leaf path.
        while parent != -1:
            feature_index = int(feature[parent])
            cur_threshold = threshold[parent]
            if X[feature_index] == 0:
                # A left child satisfies X[f] <= threshold; a right child
                # needs a value strictly greater than the threshold.
                if as_left[node]:
                    X[feature_index] = cur_threshold
                else:
                    X[feature_index] = cur_threshold + 0.00001
            node = parent
            parent = int(parent_list[node])
        X_plot.append(X)
        leave_index.append(leaf)

    return result, X_plot, leave_index


def train(dataframe, para_list, record):
    """Train a decision-tree classifier and record evaluation + tree info.

    Args:
        dataframe: input table holding the feature columns and the label column.
        para_list: parameter dict; uses "label_col", "feature_col",
            "split_rate" (train/test fractions, index 1 is the test share)
            and "max_depth".
        record: mutable dict; the training report is written under "other_info".

    Returns:
        (fitted DecisionTreeClassifier,
         dataframe with an extra "_classification_" column of per-row predictions).
    """
    logger.info("Training Start")
    feature_list = []
    X = preprocess(dataframe, para_list, feature_list)
    feature_num = X.shape[1]

    # Encode labels as indices into the sorted list of unique label values.
    label_col = para_list["label_col"]
    labels = dataframe[label_col].values
    label_list = list(np.unique(np.squeeze(labels)))
    label_list.sort()
    y = [label_list.index(label) for label in labels]
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, test_size=para_list["split_rate"][1], random_state=0)

    # ---------------------model fit---------------------
    mlmodel = DecisionTreeClassifier(max_depth=para_list["max_depth"]).fit(X_train, y_train)
    logger.info("Model Fitted")

    # ---------------------model validation---------------
    y_pred = mlmodel.predict(X_test)
    y_pred_res = mlmodel.predict(X)  # full-dataset predictions, returned to the caller

    # Evaluate on the held-out split; if some class never appears among the
    # test predictions, fall back to evaluating on the full dataset so the
    # per-class metrics stay defined.
    # BUG FIX: y_pred holds ENCODED class indices, but the original compared
    # raw labels (`label not in y_pred`), so for non-integer labels the
    # fallback always fired. Compare indices instead, and reuse the
    # already-computed full-dataset predictions rather than predicting again.
    y_pred_all = y_pred
    y_test_all = y_test
    predicted_classes = set(y_pred)
    if any(idx not in predicted_classes for idx in range(len(label_list))):
        y_pred_all = y_pred_res
        y_test_all = y

    # ---------------------get metrics--------------------
    # Confusion matrices become huge with very many classes; skip them then.
    compute_confusion_matrix = 1 if len(label_list) < 1000 else 0
    accuracy_test, precision_test, recall_test, F1_test, confusion_matrix = eval_metric(
        y_pred_all, np.array(y_test_all), compute_confusion_matrix)
    logger.debug("confusion matrix: %s", confusion_matrix)

    # ---------------------record info----------------------
    result, X_plot, leave_index = get_tree(mlmodel.tree_, feature_num)

    # Label each leaf node with the class its representative sample predicts.
    y_plot = mlmodel.predict(X_plot)
    classes = ["" for _ in range(result["n_nodes"])]
    for leaf, pred in zip(leave_index, y_plot):
        classes[leaf] = str(label_list[pred])
    result["classes"] = classes

    other_info = {
        "feature_list": feature_list,
        "feature_col": para_list["feature_col"],
        "label_list": label_list,
        "feature_importances": list(mlmodel.feature_importances_),
        "tree_structure": result,
        "accuracy_test": accuracy_test,
        "Fmeasure_test": F1_test,
        "precision_test": precision_test,
        "recall_test": recall_test,
        "confusion_matrix": confusion_matrix.tolist()
    }
    record["other_info"] = other_info

    # Append the full-dataset predictions, decoded back to original labels.
    y_pred_res = [label_list[p] for p in y_pred_res]
    output_cols = "_classification_"
    dataframe[output_cols] = y_pred_res

    return mlmodel, dataframe
